Tumor type classification with MRI

The MRI data used here come from a preprocessing notebook 📒, which produces a processed version of the original dataset 💼.

Description of the dataset

This brain tumor dataset contains 3064 T1-weighted contrast-enhanced images from 233 patients with three kinds of brain tumor: meningioma (708 slices), glioma (1426 slices), and pituitary tumor (930 slices). Due to the repository's file size limit, the whole dataset is split into 4 subsets, archived as 4 .zip files with 766 slices each. The 5-fold cross-validation indices are also provided.

This data is organized in MATLAB data format (.mat files). Each file stores a struct containing the following fields for an image:

  • cjdata.label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
  • cjdata.PID: patient ID
  • cjdata.image: image data
  • cjdata.tumorBorder: a vector storing the coordinates of discrete points on the tumor border.
      For example, [x1, y1, x2, y2,...] in which x1, y1 are planar coordinates on the tumor border.
      It was generated by manually delineating the tumor border, so it can be used to generate
      a binary image of the tumor mask.
  • cjdata.tumorMask: a binary image with 1s indicating tumor region
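
As a rough illustration of how a border vector could be turned into a binary mask, here is a minimal sketch. It assumes the border is a closed polygon stored as [x1, y1, x2, y2, ...] with x as the column index and y as the row index, as described above; `border_to_mask` and the toy square are hypothetical and not part of the dataset tooling.

```python
import numpy as np

def border_to_mask(border, shape):
    """Rasterize a closed polygon given as [x1, y1, x2, y2, ...] into a 0/1 mask."""
    pts = np.asarray(border, dtype=np.float64).reshape(-1, 2)  # rows of (x, y)
    height, width = shape
    # Full grids of row (ys) and column (xs) indices, both of shape (H, W).
    ys, xs = np.mgrid[0:height, 0:width]
    inside = np.zeros(shape, dtype=bool)
    # Even-odd rule: toggle a pixel each time a rightward ray crosses an edge.
    x0, y0 = pts[-1]
    for x1, y1 in pts:
        if y0 != y1:  # horizontal edges never cross a horizontal ray
            crosses = (y0 > ys) != (y1 > ys)
            x_at_y = x0 + (ys - y0) * (x1 - x0) / (y1 - y0)
            inside ^= crosses & (xs < x_at_y)
        x0, y0 = x1, y1
    return inside.astype(np.uint8)

# Hypothetical square border, not taken from the dataset.
demo_mask = border_to_mask([2, 2, 6, 2, 6, 6, 2, 6], shape=(8, 8))
```

The result is indexed as `demo_mask[row, column]`, matching the layout of a binary image such as cjdata.tumorMask.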

This data was used in the following papers:

  • Cheng, Jun, et al. "Enhanced Performance of Brain Tumor Classification via Tumor Region Augmentation and Partition." PloS one 10.10 (2015).
  • Cheng, Jun, et al. "Retrieval of Brain Tumors by Adaptive Spatial Pooling and Fisher Vector Representation." PloS one 11.6 (2016). MATLAB source code is available on GitHub: https://github.com/chengjun583/brainTumorRetrieval

Imports

In [1]:
from forgebox.imports import *
from tqdm.notebook import tqdm
import pytorch_lightning as pl
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact

Config & Locations

In [2]:
DATA = Path("/GCI/brain_mri/")
MATS = DATA/"mats"
NUMPYS = DATA/"npy"
In [3]:
WEIGHTS = DATA/"weights"
WEIGHTS.mkdir(exist_ok = True)

Meta data, tabulated records

A pandas dataframe of metadata holds the information about each image.

Column information
  • pid: patient id
  • img: location of the image numpy file
  • label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
  • shape: shape of the image; only the 512x512 images are used here
  • img_id: original mat file id of the image
In [4]:
df = pd.read_csv(DATA/"meta.csv")
In [5]:
df["img_id"] = df.img.apply(lambda x:int(Path(x).name.split('.')[0]))
df = df.query("shape=='512_512'").sort_values(by=["img_id"]).reset_index(drop=True)
df.sample(10)
Out[5]:
pid img mask boarder label shape img_id
2474 MR033389B /GCI/brain_mri/mats/2490.mat_img.npy /GCI/brain_mri/mats/2490.mat_mask.npy /GCI/brain_mri/mats/2490.mat_bd.npy 2.0 512_512 2490
1063 108858 /GCI/brain_mri/mats/1067.mat_img.npy /GCI/brain_mri/mats/1067.mat_mask.npy /GCI/brain_mri/mats/1067.mat_bd.npy 3.0 512_512 1067
1706 104019 /GCI/brain_mri/mats/1722.mat_img.npy /GCI/brain_mri/mats/1722.mat_mask.npy /GCI/brain_mri/mats/1722.mat_bd.npy 3.0 512_512 1722
2159 MR039473 /GCI/brain_mri/mats/2175.mat_img.npy /GCI/brain_mri/mats/2175.mat_mask.npy /GCI/brain_mri/mats/2175.mat_bd.npy 2.0 512_512 2175
688 99089 /GCI/brain_mri/mats/689.mat_img.npy /GCI/brain_mri/mats/689.mat_mask.npy /GCI/brain_mri/mats/689.mat_bd.npy 1.0 512_512 689
1318 103731 /GCI/brain_mri/mats/1334.mat_img.npy /GCI/brain_mri/mats/1334.mat_mask.npy /GCI/brain_mri/mats/1334.mat_bd.npy 3.0 512_512 1334
2157 MR037458B /GCI/brain_mri/mats/2173.mat_img.npy /GCI/brain_mri/mats/2173.mat_mask.npy /GCI/brain_mri/mats/2173.mat_bd.npy 2.0 512_512 2173
2836 MR048994 /GCI/brain_mri/mats/2852.mat_img.npy /GCI/brain_mri/mats/2852.mat_mask.npy /GCI/brain_mri/mats/2852.mat_bd.npy 2.0 512_512 2852
715 100820 /GCI/brain_mri/mats/716.mat_img.npy /GCI/brain_mri/mats/716.mat_mask.npy /GCI/brain_mri/mats/716.mat_bd.npy 2.0 512_512 716
2711 MR037458C /GCI/brain_mri/mats/2727.mat_img.npy /GCI/brain_mri/mats/2727.mat_mask.npy /GCI/brain_mri/mats/2727.mat_bd.npy 2.0 512_512 2727
In [6]:
df.vc("label")
Out[6]:
label
2.0 1426
3.0 915
1.0 708

Interactive visualization

Visualization helpers

In [7]:
def vis_patient(pid):
    sub_df = df.query(f"pid=='{pid}'").sort_values(by="img_id")
    img_arr = np.stack(list(np.load(i) for i in sub_df.img))\
        .astype(np.float32)/1000
    mask_arr = np.stack(list(np.load(i) for i in sub_df["mask"]))\
        .astype(np.float32)
    @interact
    def show_mri(i = (1,len(img_arr))):
        print(list(sub_df.img)[i-1])
        rgb_arr = np.stack([
          mask_arr[i-1],
          np.clip(img_arr[i-1]-mask_arr[i-1],0.,1.),
          img_arr[i-1],                  
        ], axis=-1)

        # rgb_arr = img_arr[i-1].astype(np.float32)
        # print(rgb_arr[200:230,200:230])
        display(plt.imshow(rgb_arr))
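
The colour overlay built inside `show_mri` can be looked at in isolation. The sketch below repeats the same channel stacking on a toy 2x2 image: red carries the mask, green carries the image with the masked region suppressed, and blue carries the image unchanged. `overlay_mask` and the toy arrays are illustrative names, not part of the notebook.

```python
import numpy as np

def overlay_mask(img, mask):
    """Return an H x W x 3 float array: red=mask, green=img minus mask (clipped), blue=img."""
    img = np.asarray(img, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)
    return np.stack([mask, np.clip(img - mask, 0.0, 1.0), img], axis=-1)

# Toy intensities in [0, 1] and a mask covering a single pixel.
toy_img = np.array([[0.2, 0.8], [0.5, 1.0]], dtype=np.float32)
toy_mask = np.array([[0, 1], [0, 0]], dtype=np.float32)
rgb = overlay_mask(toy_img, toy_mask)
```

Inside the mask the green channel drops to zero, so the tumor region shows up in red and blue tones while the rest of the slice stays cyan-grey.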

Preview image and mask

In [8]:
vis_patient('100360')

Learning

Dataset class and split helper

In [9]:
class mri_data(Dataset):
    def __init__(self, df: pd.DataFrame):
        super().__init__()
        self.df = df.reset_index(drop = True)
    
    def __len__(self):
        return len(self.df)

    def __repr__(self):
        return f"MRI Dataset:\n\t{len(self.df.pid.unique())} patients, {len(self)} slices"

    def __getitem__(self,idx):
        row = dict(self.df.loc[idx])
        img = np.load(row["img"])
        img = img/(img.max())
        return img[None, ...], row['label']-1

def split_by(
    df: pd.DataFrame,
    col: str,
    val_ratio: float=.2
):
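    """
    Split df into train/validation frames by the unique values of a column,
    so that all rows sharing a value (e.g. one patient) land on the same side.

    - col: column whose unique values are split
    - val_ratio: fraction of the unique values assigned to validation
    """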
    uniques = np.array(list(set(list(df[col]))))
    validation_ids = np.random.choice(
        uniques, size=int(len(uniques)*val_ratio), replace=False)
    val_slice = df[col].isin(validation_ids)
    return df[~val_slice].sample(frac=1.).reset_index(drop=True),\
        df[val_slice].reset_index(drop=True)
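
Splitting by patient rather than by slice matters because slices from one patient are highly similar, and a slice-level split would leak them across train and validation. The following standard-library sketch shows the same idea as `split_by` on a toy list of ids; `split_groups`, the seed, and the ids are assumptions made for illustration.

```python
import random

def split_groups(group_ids, val_ratio=0.2, seed=0):
    """Return (train_idx, val_idx) so that each group id lands on one side only."""
    uniques = sorted(set(group_ids))
    rng = random.Random(seed)
    n_val = int(len(uniques) * val_ratio)
    val_groups = set(rng.sample(uniques, n_val))
    train_idx = [i for i, g in enumerate(group_ids) if g not in val_groups]
    val_idx = [i for i, g in enumerate(group_ids) if g in val_groups]
    return train_idx, val_idx

# Hypothetical patient ids, several slices per patient.
toy_pids = ["p1", "p1", "p2", "p3", "p3", "p3", "p4", "p5", "p5", "p6"]
train_idx, val_idx = split_groups(toy_pids, val_ratio=0.34)
train_pids = {toy_pids[i] for i in train_idx}
val_pids = {toy_pids[i] for i in val_idx}
```

Note that the ratio applies to the number of patients, not slices, so the slice-level proportion of the validation set can differ from `val_ratio`.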
In [10]:
train_df, val_df = split_by(df, "pid")
In [11]:
total_ds = mri_data(df)
train_ds = mri_data(train_df)
val_ds = mri_data(val_df)
In [12]:
train_ds, val_ds
Out[12]:
(MRI Dataset:
 	186 patients, 2392 slices,
 MRI Dataset:
 	46 patients, 657 slices)
In [13]:
x, y = train_ds[5]
x.shape, y
Out[13]:
((1, 512, 512), 1.0)

Mean & Standard Deviation

Per-slice mean and standard deviation, averaged over the entire dataset; these are the statistics needed for input normalization in a preprocessing layer.

In [14]:
all_x = []
for i in tqdm(range(len(total_ds))):
    x,yy = total_ds[i]
    all_x.append(np.array([x.mean(), x.std()]))

In [15]:
all_arr = np.array(all_x)
x_mean, x_std = all_arr.mean(0)
x_mean, x_std
Out[15]:
(0.15574614257151373, 0.16054656673109854)
In [16]:
all_arr[:,0].min(), all_arr[:,0].max(),all_arr[:,1].min(), all_arr[:,1].max()
Out[16]:
(0.055120470304629875,
 0.29819597127486247,
 0.08168585048503188,
 0.24941559801719104)
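
The cells above average per-slice means and per-slice standard deviations. The averaged standard deviation is generally smaller than the standard deviation pooled over all pixels, because it ignores how much the slice means vary. If pooled statistics are wanted instead, a minimal sketch using running sums could look like this; `pooled_mean_std` and the synthetic images are assumptions, not the notebook's method.

```python
import numpy as np

def pooled_mean_std(images):
    """Mean and std over all pixels of all images, accumulated with running sums."""
    n = 0
    total = 0.0
    total_sq = 0.0
    for img in images:
        img = np.asarray(img, dtype=np.float64)
        n += img.size
        total += img.sum()
        total_sq += np.square(img).sum()
    mean = total / n
    # Guard against tiny negative values caused by floating point rounding.
    std = np.sqrt(max(total_sq / n - mean ** 2, 0.0))
    return mean, std

# Synthetic stand-ins for slices with different brightness ranges.
rng = np.random.default_rng(0)
toy_images = [rng.random((4, 4)) * scale for scale in (0.2, 1.0, 0.5)]
pooled_mean, pooled_std = pooled_mean_std(toy_images)
mean_of_stds = float(np.mean([img.std() for img in toy_images]))
```

Only three scalars are kept in memory, so the same loop works when the images are loaded one at a time from disk.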

Model structure

Experiments with EfficientNet

In [17]:
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import Conv2dStaticSamePadding
In [18]:
model = EfficientNet.from_pretrained("efficientnet-b5", num_classes=3)
Loaded pretrained weights for efficientnet-b5
In [19]:
model._conv_stem = Conv2dStaticSamePadding(
  1, 48, kernel_size=(3, 3), stride=(2, 2), bias=False, image_size=512
)
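
Replacing `_conv_stem` gives the first layer freshly initialised weights, so the pretrained stem filters are not reused. One common alternative, which this notebook does not use, is to sum the pretrained RGB filters over the input-channel axis. The NumPy sketch below, with random stand-in weights, checks the underlying identity on a single patch: a grayscale patch copied into three channels and filtered by the RGB kernel gives the same response as the single-channel patch filtered by the summed kernel.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for pretrained stem weights: (out_channels, in_channels, kH, kW).
rgb_kernel = rng.normal(size=(48, 3, 3, 3))
# Collapse the three input channels into one by summing them.
gray_kernel = rgb_kernel.sum(axis=1, keepdims=True)  # shape (48, 1, 3, 3)

gray_patch = rng.random((1, 3, 3))            # one 3x3 grayscale patch
rgb_patch = np.repeat(gray_patch, 3, axis=0)  # same patch copied into 3 channels

# Filter response at one spatial location for each output channel.
out_rgb = np.einsum("oikl,ikl->o", rgb_kernel, rgb_patch)
out_gray = np.einsum("oikl,ikl->o", gray_kernel, gray_patch)
```

Whether this initialisation helps on MRI slices is an empirical question; the sketch only shows that the summed kernel reproduces the pretrained response for replicated grayscale input.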

Test model pipeline

In [20]:
model(torch.FloatTensor(x)[None,...]).shape
Out[20]:
torch.Size([1, 3])

Lightning Data Module

In [21]:
class PlData(pl.LightningDataModule):
    def __init__(self, train_df, val_df, bs):
        super().__init__()
        self.bs = bs
        self.train_df = train_df
        self.val_df = val_df
        self.train_ds = mri_data(self.train_df)
        self.val_ds = mri_data(self.val_df)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            shuffle=True,
            num_workers=8,
            batch_size=self.bs
        )

    def val_dataloader(self):
        """
        Validation dataloader;
        batch size is twice the training batch size
        """
        return DataLoader(
            self.val_ds,
            shuffle=False,
            num_workers=8,
            batch_size=self.bs * 2
        )

Lightning Module

In [22]:
class PlMRIModel(pl.LightningModule):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.softmax = nn.Softmax(dim=-1)
        self.crit = nn.CrossEntropyLoss()
        self.accuracy_f = pl.metrics.Accuracy()

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.base.parameters(), lr=1e-4)
    
    def calc_all_metrics(
        self,
        y_, y, is_train
    ):
        phase = "train" if is_train else "val"
        
        probs = self.softmax(y_)  # class probabilities from the raw model outputs
        acc = self.accuracy_f(probs, y)
        
        self.log(f'{phase}_acc', acc)
        
    def training_step(self, batch, batch_idx):
        x,y = batch
        x = x.float(); y=y.long()
        y_ = self(x)
        loss = self.crit(y_, y)
        
        self.log('train_loss', loss)
        self.calc_all_metrics(y_, y, True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x,y = batch
        x = x.float(); y=y.long()
        y_ = self(x)
        loss = self.crit(y_, y)

        self.log('val_loss', loss)
        self.calc_all_metrics(y_, y, False)
        
        return loss
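
In the module above, `nn.CrossEntropyLoss` receives the raw model outputs, while the accuracy metric receives the softmax output. Softmax preserves the ordering of the scores, so the predicted class is the same either way. A small NumPy sketch with made-up scores illustrates this; `softmax`, `toy_scores`, and `toy_labels` are hypothetical names.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up raw outputs for three samples and three classes.
toy_scores = np.array([[2.0, -1.0, 0.5],
                       [0.1, 0.3, 0.2],
                       [-3.0, -2.0, -2.5]])
toy_labels = np.array([0, 1, 2])

probs = softmax(toy_scores)
pred_from_scores = toy_scores.argmax(-1)
pred_from_probs = probs.argmax(-1)
accuracy = float((pred_from_probs == toy_labels).mean())
```

The softmax is therefore only needed when calibrated-looking probabilities are wanted, as in the prediction tables later on.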
In [23]:
pl_data = PlData(train_df, val_df, bs=8)
pl_model = PlMRIModel(model)

Training configuration

Logging and callbacks

In [24]:
# loggers
logger = pl.loggers.TensorBoardLogger("/GCI/tensorboard/brain_mri", name="cls")

# callbacks
# val_acc should increase, so stop once it no longer improves
early = pl.callbacks.EarlyStopping(monitor="val_acc", mode="max")
saving = pl.callbacks.ModelCheckpoint(str(WEIGHTS/"cls_models"), monitor="val_acc", save_top_k = 3, mode="max")
/anaconda3/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Checkpoint directory /nvme/GCI/brain_mri/weights exists and is not empty.
  warnings.warn(*args, **kwargs)
In [25]:
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early, saving],
    checkpoint_callback=True,
    gpus=1,
    fast_dev_run=False,
)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

Training, with a progress bar to watch

In [26]:
trainer.fit(pl_model,pl_data)
  | Name       | Type             | Params
------------------------------------------------
0 | base       | EfficientNet     | 28.3 M
1 | softmax    | Softmax          | 0     
2 | crit       | CrossEntropyLoss | 0     
3 | accuracy_f | Accuracy         | 0     
------------------------------------------------
28.3 M    Trainable params
0         Non-trainable params
28.3 M    Total params

Out[26]:
1

Visualize Prediction

Switch to evaluation mode

In [30]:
pl_model = pl_model.eval()

An array-to-array pipeline

In [44]:
from forgebox.ftorch.cuda import CudaHandler
In [39]:
cu = CudaHandler()
dev = cu.idle()()
>>> 2 cuda devices found >>>
Device 0: 
	name:Tesla V100-PCIE-32GB
	used:1497MB	free:31013MB
Device 1: 
	name:Tesla V100-PCIE-32GB
	used:19033MB	free:13477MB
cuda stats refreshed
Found the most idle GPU: cuda:0, 31013 MB Mem remained
In [41]:
pl_model = pl_model.to(dev)
In [42]:
def pred(x: np.ndarray) -> np.ndarray:
    """
    Predict class probabilities (array of length 3) for a batch holding one image array
    """
    with torch.no_grad():
        return pl_model.softmax(
            pl_model(torch.FloatTensor(x).to(dev))).cpu().detach().numpy()[0]

View the validation prediction interactively

In [43]:
@interact
def see_val(idx = (0,len(val_ds)-1)):  # slider bounds are inclusive
    x,y = val_ds[idx]
    print(f"Prediction, {pred(x[None,:])}, label {np.eye(3)[int(y)]}")
    plt.imshow(x[0])
In [45]:
preds = []
labels = []

for i in tqdm(range(len(val_ds))):
    x,y = val_ds[i]
    preds.append(pred(x[None, :]))
    labels.append(y)
pred_arr = np.stack(preds)

In [198]:
pred_df = pd.DataFrame(dict(
    idx=range(len(val_ds)),
    meningioma=pred_arr[:,0],
    glioma=pred_arr[:,1],
    pituitary=pred_arr[:,2],
    pred_idx=pred_arr.argmax(-1),
    labels=labels
))
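
Before looking at individual slices, the `pred_idx` and `labels` columns can be summarised in a confusion matrix, which shows which tumor types get mixed up. The sketch below uses made-up labels and predictions rather than the notebook's results; `confusion_matrix` is a hypothetical helper written with plain NumPy.

```python
import numpy as np

def confusion_matrix(labels, preds, n_classes=3):
    """Count matrix with true classes as rows and predicted classes as columns."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    rows = np.asarray(labels, dtype=np.int64)
    cols = np.asarray(preds, dtype=np.int64)
    np.add.at(cm, (rows, cols), 1)  # accumulates correctly for repeated index pairs
    return cm

# Made-up values: 0 = meningioma, 1 = glioma, 2 = pituitary.
toy_labels = [0, 0, 1, 1, 2, 2, 2]
toy_preds = [0, 1, 1, 1, 0, 2, 2]
cm = confusion_matrix(toy_labels, toy_preds)
per_class_recall = cm.diagonal() / cm.sum(axis=1)
```

With the real data, the same call would take `pred_df.labels` and `pred_df.pred_idx`; the float labels are cast to integers inside the helper.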
In [53]:
from forgebox.images import widgets
In [79]:
def full_getter(self, idx):
    # attached to mri_data below as an instance method, so the first argument is the dataset
    row = dict(self.df.loc[idx])
    img = np.load(row["img"])
    mask = np.load(row["mask"])
    img = img/(img.max())
    return img[None, ...], mask, row['label']-1
mri_data.full_getter = full_getter
In [177]:
def create_img(i):
    x, y, z = val_ds.full_getter(i)
    img = x[0]
    y = y.astype(np.float32)
    img_arr = np.stack([np.zeros_like(img),img,img],axis=-1)
    img_arr2 = np.stack([y*.5,img,img],axis=-1)
    img_all = np.concatenate([img_arr, img_arr2], axis=1)
    # scale [0, 1] floats to 0..255 and use unsigned bytes so bright pixels do not wrap around
    return Image.fromarray((img_all*255).astype(np.uint8))
In [185]:
cancer_types = ["meningioma","glioma", "pituitary"]
In [213]:
def view_top(cancer_type):
    """
    Show the 12 validation slices with the highest predicted probability for cancer_type
    """
    top_ = pred_df.sort_values(by=cancer_type, ascending=False).head(12)
    img_list = list(create_img(i) for i in top_.idx)
    display(top_)
    widgets.view_images(*img_list, num_per_row=2)()


def view_top_error(cancer_type):
    """
    Show the 12 misclassified validation slices with the highest predicted probability for cancer_type
    """
    top_ = pred_df.query("pred_idx!=labels")\
        .sort_values(by=cancer_type, ascending=False).head(12)
    img_list = list(create_img(i) for i in top_.idx)
    display(top_)
    widgets.view_images(*img_list, num_per_row=2)()

Top confidence

meningioma

In [211]:
_ = view_top("meningioma")
idx meningioma glioma pituitary pred_idx labels
8 8 0.999990 0.000003 0.000007 0 0.0
190 190 0.999977 0.000005 0.000019 0 0.0
58 58 0.999974 0.000005 0.000021 0 0.0
60 60 0.999970 0.000012 0.000017 0 0.0
11 11 0.999969 0.000010 0.000021 0 0.0
87 87 0.999968 0.000015 0.000018 0 0.0
77 77 0.999966 0.000023 0.000011 0 0.0
141 141 0.999966 0.000019 0.000015 0 0.0
100 100 0.999966 0.000022 0.000012 0 0.0
149 149 0.999965 0.000008 0.000027 0 0.0
123 123 0.999961 0.000025 0.000014 0 0.0
124 124 0.999961 0.000010 0.000029 0 0.0

glioma

In [201]:
_ = view_top("glioma")
idx meningioma glioma pituitary pred_idx labels
526 526 7.010519e-09 1.0 5.764930e-15 1 1.0
630 630 7.059064e-09 1.0 1.108315e-14 1 1.0
545 545 3.636871e-09 1.0 5.988802e-13 1 1.0
546 546 1.316429e-12 1.0 2.641802e-17 1 1.0
547 547 4.953895e-11 1.0 1.852250e-16 1 1.0
553 553 3.441300e-08 1.0 6.598324e-12 1 1.0
561 561 1.924005e-08 1.0 4.247263e-09 1 1.0
564 564 2.145240e-09 1.0 9.915583e-11 1 1.0
575 575 2.261232e-08 1.0 1.533782e-11 1 1.0
587 587 5.237261e-10 1.0 2.430004e-13 1 1.0
588 588 3.173781e-09 1.0 7.127144e-12 1 1.0
543 543 3.027300e-09 1.0 1.637493e-10 1 1.0

pituitary

In [202]:
_ = view_top("pituitary")
idx meningioma glioma pituitary pred_idx labels
422 422 0.000008 0.000007 0.999984 2 2.0
448 448 0.000007 0.000014 0.999980 2 2.0
424 424 0.000014 0.000010 0.999977 2 2.0
453 453 0.000009 0.000016 0.999975 2 2.0
297 297 0.000029 0.000005 0.999966 2 2.0
413 413 0.000032 0.000020 0.999948 2 2.0
418 418 0.000033 0.000029 0.999938 2 2.0
469 469 0.000051 0.000023 0.999926 2 2.0
294 294 0.000041 0.000041 0.999918 2 2.0
447 447 0.000054 0.000042 0.999904 2 2.0
454 454 0.000041 0.000091 0.999869 2 2.0
465 465 0.000126 0.000013 0.999861 2 2.0

Top errors

In [214]:
_ = view_top_error("meningioma")
idx meningioma glioma pituitary pred_idx labels
255 255 0.999887 0.000089 0.000024 0 2.0
254 254 0.999295 0.000494 0.000211 0 2.0
216 216 0.999218 0.000741 0.000041 0 1.0
364 364 0.998501 0.000522 0.000978 0 2.0
432 432 0.997012 0.000648 0.002340 0 2.0
231 231 0.995858 0.003967 0.000175 0 1.0
281 281 0.995844 0.000980 0.003177 0 2.0
282 282 0.994215 0.002013 0.003771 0 2.0
341 341 0.994157 0.000441 0.005402 0 2.0
363 363 0.993935 0.000335 0.005731 0 2.0
342 342 0.990529 0.000298 0.009172 0 2.0
410 410 0.986810 0.000386 0.012803 0 2.0
In [208]:
_ = view_top_error("glioma")
idx meningioma glioma pituitary pred_idx labels
3 3 0.005896 0.980409 0.013695 1 0.0
279 279 0.017606 0.938792 0.043602 1 2.0
283 283 0.077100 0.900265 0.022635 1 2.0
168 168 0.124373 0.875516 0.000110 1 0.0
101 101 0.385000 0.502154 0.112846 1 0.0
613 613 0.565242 0.434733 0.000025 0 1.0
309 309 0.457543 0.345044 0.197413 0 2.0
280 280 0.625894 0.342799 0.031307 0 2.0
235 235 0.740706 0.258345 0.000949 0 1.0
204 204 0.809828 0.189969 0.000203 0 1.0
236 236 0.815339 0.182664 0.001998 0 1.0
618 618 0.848628 0.150943 0.000429 0 1.0
In [209]:
_ = view_top_error("pituitary")
idx meningioma glioma pituitary pred_idx labels
360 360 0.519692 0.000564 0.479744 0 2.0
443 443 0.528922 0.000292 0.470786 0 2.0
307 307 0.468920 0.075255 0.455825 0 2.0
328 328 0.556394 0.002041 0.441565 0 2.0
305 305 0.554723 0.005171 0.440106 0 2.0
198 198 0.577231 0.001177 0.421592 0 2.0
333 333 0.579510 0.000224 0.420265 0 2.0
306 306 0.600709 0.003719 0.395572 0 2.0
272 272 0.638502 0.003873 0.357625 0 2.0
262 262 0.667299 0.004909 0.327793 0 2.0
462 462 0.718708 0.001291 0.280001 0 2.0
444 444 0.727803 0.000174 0.272024 0 2.0